library(stargazer)
library(survival)
library(MASS)
library(ggplot2)
library(ggthemes)
library(reshape2)
library(tidyverse)
source("utility_functions_WRVC.R")

## Functions to make bootstrapping of confidence intervals easier
conint <- function(x, level = 0.05) {
  # Percentile bounds of the first column of `x`: the `level` and
  # `1 - level` quantiles (5% and 95% with the default).
  tail_probs <- c(level, 1 - level)
  first_col <- x[, 1]
  quantile(first_col, tail_probs)
}

conf_int <- function(x) {
  # Percentile band per time point.
  #
  # x: data frame whose first column holds the stacked values (point estimate
  #    and bootstrap replicates) and whose second column holds the time each
  #    value belongs to (expected to be numeric).
  #
  # Returns a data frame with columns `low`, `high` (the bounds produced by
  # conint() for each time) and `time`, one row per distinct time in
  # ascending order.
  #
  # The time column is derived from `x` itself rather than from a global
  # object, so the result always lines up with the groups actually present.
  timediv <- split(x, x[, 2])
  conf <- lapply(timediv, conint)
  data.frame(
    low = vapply(conf, function(q) q[[1]], numeric(1), USE.NAMES = FALSE),
    high = vapply(conf, function(q) q[[2]], numeric(1), USE.NAMES = FALSE),
    # split() orders its groups by the sorted distinct values of x[, 2],
    # which is exactly sort(unique(.)) for a numeric column.
    time = sort(unique(x[, 2]))
  )
}

## Estimate the cumulative incidence functions (CIFs) of a two-cause competing
## risks model for given values of the independent variables.
##
## fit1, fit2: fitted models for cause 1 and cause 2; each is passed to
##             basehaz() and to predict(type = "risk").
## cov:        newdata handed to predict(); presumably a single row, since the
##             resulting risk score multiplies the whole hazard vector.
##
## Returns an unnamed list: [[1]] the CIF for cause 1, [[2]] the CIF for
## cause 2, each with one value per row of the baseline hazard table.
cif <- function(fit1, fit2, cov) {
  cum1 <- basehaz(fit1, centered = TRUE)
  cum2 <- basehaz(fit2, centered = TRUE)

  # The two hazards are combined element by element below; differing lengths
  # would be recycled silently and give a meaningless result.
  if (nrow(cum1) != nrow(cum2)) {
    stop("cif(): the two fits have baseline hazards on different time grids (",
         nrow(cum1), " vs ", nrow(cum2), " rows)", call. = FALSE)
  }

  # Per-time hazard increments from the cumulative baseline hazard. The first
  # increment is the first cumulative value; diff() also copes with a
  # baseline that has only one time point.
  bhaz.t1 <- diff(c(0, cum1[, 1]))
  bhaz.t2 <- diff(c(0, cum2[, 1]))

  # Scale the increments by the risk score at the requested covariates.
  haz.t1 <- bhaz.t1 * predict(fit1, newdata = cov, type = "risk")
  haz.t2 <- bhaz.t2 * predict(fit2, newdata = cov, type = "risk")

  # Overall survival from both causes, evaluated at each time point.
  sf <- exp(-(cumsum(haz.t1) + cumsum(haz.t2)))

  # NOTE(review): the textbook cause-specific CIF weights each increment by
  # the survival just before the time point, i.e. c(1, sf[-length(sf)]).
  # That lagged vector used to be computed here but was never used. The
  # weighting by `sf` is kept so that results are unchanged; confirm which
  # form is intended.
  cif1 <- cumsum(haz.t1 * sf)
  cif2 <- cumsum(haz.t2 * sf)

  list(cif1, cif2)
}

## Fit the main model (see "Main_replication.R")
Full_Model <- "Num_Actor + Revolution + Res_Conf + PKO + Mountain + lag_IMR + ELF + Victory + lag_Growth + Agreement + Post_CW"
dat <- read_csv(file = "Who_Restarts_Violent_Conflict_Replication.csv")
main_dat <- arrangement(dat, preact_def = 2, recast_def = 1, both = 1)
main_res <- competing_risks(main_dat, indep = Full_Model, subest = TRUE)
compdat <- model_extract(main_dat, Full_Model)
## Keep only the rows observed at t_time == 1
unit_data <- subset(compdat, compdat$t_time == 1)

## Covariates that appear in the fitted model (element 4 of main_res), and
## their column-wise medians as a one-row data frame used for prediction
first_cov <- unit_data[, names(unit_data) %in% names(main_res[[4]]$coefficients)]
med <- as.data.frame(t(apply(first_cov, MARGIN = 2, FUN = median)))

## Point estimate of the cumulative incidence function at the medians
cfm <- cif(main_res[[4]], main_res[[5]], cov = med)

## Point estimates of the baseline hazards from the main model
bch1 <- basehaz(main_res[[4]], centered = TRUE)
bch2 <- basehaz(main_res[[5]], centered = TRUE)

## Scale by the risk score at the medians. The multiplication touches every
## column, so the time column (column 2) is restored afterwards.
bchm1 <- bch1 * predict(main_res[[4]], newdata = med, type = "risk")
bchm1[, 2] <- bch1[, 2]
bchm1$id <- "repeated"

bchm2 <- bch2 * predict(main_res[[5]], newdata = med, type = "risk")
bchm2[, 2] <- bch2[, 2]
bchm2$id <- "recast"

######################################################################
#### Bootstrapping the baseline cause-specific hazards (Figure 2) ####
######################################################################

set.seed(1)

## One row per bootstrap replicate: unit ids (pid) drawn with replacement
num_unit <- length(unique(compdat$pid))
boot_obj <- matrix(nrow = 1000, ncol = num_unit)
for (i in seq_len(nrow(boot_obj))) {
  boot_obj[i, ] <- sample(unique(compdat$pid), size = num_unit, replace = TRUE)
}

## Containers for the fitted models and covariate medians of each replicate
boot_result1 <- vector("list", 1000)
boot_result2 <- vector("list", 1000)
boot_med <- data.frame()

# The code here may take a long time (one model fit per bootstrap replicate)

for (j in seq_len(nrow(boot_obj))) {
  id <- boot_obj[j, ]
  ## Constructing the bootstrapped sample: all rows of each drawn unit, in
  ## draw order, with a unit repeated as often as it was drawn. The rows are
  ## extracted in one indexing step instead of growing a data frame with
  ## rbind() inside a loop, which copied the whole sample on every iteration.
  ## Row contents and order are the same; only the row labels may differ.
  di <- compdat[unlist(lapply(id, function(p) which(compdat$pid == p))), ,
                drop = FALSE]
  unii <- subset(di, di$t_time == 1)
  compi <- na.omit(unii[, which(names(unii) %in% c(names(main_res[[4]]$coefficients)))])

  ## Storing the median values of independent variables
  medi <- apply(compi, median, MARGIN = 2)
  medi <- as.data.frame(t(medi))
  boot_med <- rbind(boot_med, medi)

  ## Fit the competing risk model to the bootstrapped sample
  resi <- competing_risks(di, indep = Full_Model, subest = FALSE)
  boot_result1[[j]] <- resi[[4]]
  boot_result2[[j]] <- resi[[5]]
  print(paste(j, "/", nrow(boot_obj)))
}

## To save
#saveRDS(boot_result1, "boot_result1.R")
#saveRDS(boot_result2, "boot_result2.R")
#saveRDS(boot_med, "boot_med.R")

#boot_result1 <- readRDS("boot_result1.R")
#boot_result2 <- readRDS("boot_result2.R")
#boot_med <- readRDS("boot_med.R")

## Stack the point estimate and every bootstrap replicate of the predicted
## cumulative hazards (and, optionally, CIFs) at the covariate values `cov`.
##
## cov:          newdata passed to predict() and cif().
## cif:          if TRUE, the cumulative incidence functions are stacked too.
## bchm1, bchm2: point-estimate hazards for cause 1 and cause 2, one value per
##               row of bch1 / bch2.
## cfm:          point-estimate CIFs in the form returned by cif(); only read
##               when `cif` is TRUE.
##
## Reads the globals bch1 and bch2 (reference time grids, column 2),
## boot_result1 and boot_result2 (bootstrap fits), and the function cif().
##
## Returns list(boot_ch1, boot_ch2), followed by boot_cf1 and boot_cf2 when
## `cif` is TRUE. Each element is a data frame with a value column and a
## `time` column: the point estimate first, then one stretch of rows per
## bootstrap replicate.
##
## Differences from the earlier implementation:
##  * A replicate that ends before the reference grid is extended by carrying
##    its last value forward to the later reference times. Previously the
##    last row was duplicated with its time left unchanged, so the replicate
##    was counted several times at its last time and never at later ones.
##  * Each cause is extended against its own grid (bch1 or bch2); previously
##    the row count computed for cause 1 was applied to both causes.
##  * The cause-2 CIF point estimate is timed with bch2 (it used bch1).
##  * The pieces are collected in lists and bound once at the end.
b_pred <- function(cov, cif = FALSE, bchm1, bchm2, cfm) {
  want_cif <- cif == TRUE

  # Append copies of the last row of `curve` for every reference time beyond
  # its last time, so the step function stays flat up to the end of the grid.
  extend_to <- function(curve, ref_time) {
    later <- ref_time[ref_time > max(curve[, 2])]
    if (length(later) == 0) {
      return(curve)
    }
    filler <- curve[rep(nrow(curve), length(later)), , drop = FALSE]
    filler[, 2] <- later
    rbind(curve, filler)
  }

  n_boot <- length(boot_result1)
  ch1 <- vector("list", n_boot + 1L)
  ch2 <- vector("list", n_boot + 1L)
  ch1[[1]] <- data.frame(hazard = bchm1, time = bch1[, 2])
  ch2[[1]] <- data.frame(hazard = bchm2, time = bch2[, 2])
  if (want_cif) {
    cf1 <- vector("list", n_boot + 1L)
    cf2 <- vector("list", n_boot + 1L)
    cf1[[1]] <- data.frame(chf = cfm[[1]], time = bch1[, 2])
    cf2[[1]] <- data.frame(chf = cfm[[2]], time = bch2[, 2])
  }

  for (l in seq_len(n_boot)) {
    fit1 <- boot_result1[[l]]
    fit2 <- boot_result2[[l]]

    # Baseline hazards for this bootstrapped sample
    base1 <- basehaz(fit1, centered = TRUE)
    base2 <- basehaz(fit2, centered = TRUE)

    # Hazards at `cov`. The multiplication scales every column, so the time
    # column (column 2) is restored afterwards.
    cum1 <- base1 * predict(fit1, cov, "risk")
    cum2 <- base2 * predict(fit2, cov, "risk")
    cum1[, 2] <- base1[, 2]
    cum2[, 2] <- base2[, 2]

    ch1[[l + 1L]] <- extend_to(cum1, bch1[, 2])
    ch2[[l + 1L]] <- extend_to(cum2, bch2[, 2])

    if (want_cif) {
      # CIFs for this bootstrapped sample. The argument `cif` is a logical,
      # so this call still resolves to the function cif() defined elsewhere.
      cfm_l <- cif(fit1, fit2, cov)
      cf1[[l + 1L]] <- extend_to(
        data.frame(chf = cfm_l[[1]], time = base1[, 2]), bch1[, 2]
      )
      cf2[[l + 1L]] <- extend_to(
        data.frame(chf = cfm_l[[2]], time = base2[, 2]), bch2[, 2]
      )
    }
  }

  out <- list(do.call(rbind, ch1), do.call(rbind, ch2))
  if (want_cif) {
    out <- c(out, list(do.call(rbind, cf1), do.call(rbind, cf2)))
  }
  out
}


## Hazards and cumulative incidence with covariates fixed at the median
pred <- b_pred(
  cov = med, cif = TRUE,
  bchm1 = bchm1[, 1], bchm2 = bchm2[, 1], cfm = cfm
)
boot_ch1 <- pred[[1]]
boot_ch2 <- pred[[2]]
boot_cf1 <- pred[[3]]
boot_cf2 <- pred[[4]]

## Bootstrap bands per time point
haz1_int <- conf_int(boot_ch1)
haz2_int <- conf_int(boot_ch2)
cif1_int <- conf_int(boot_cf1)
cif2_int <- conf_int(boot_cf2)

point_haz <- rbind(bchm1, bchm2)
point_haz$type <- rep(c("repeated", "recast"),
                      times = c(nrow(bchm1), nrow(bchm2)))

## Point estimate, lower and upper band for each cause, side by side
pmat <- data.frame(
  time = haz1_int$time,
  haz1p = bchm1[, 1], low = haz1_int$low, haz1high = haz1_int$high,
  haz2p = bchm2[, 1], haz2low = haz2_int$low, haz2high = haz2_int$high
)
matplot(
  x = pmat$time, y = pmat[, -1], type = "s",
  xlab = "time", ylab = "estimated hazard", xlim = c(0, 40),
  col = c("blue", "blue", "blue", "red", "red", "red"),
  lty = c(1, 2, 2, 1, 2, 2)
)
legend("topleft", legend = c("repeated", "recast"),
       col = c("blue", "red"), lty = 1)

pmatci <- data.frame(
  time = cif1_int$time,
  cif1p = cfm[[1]], cif1low = cif1_int$low, cif1high = cif1_int$high,
  cif2p = cfm[[2]], cif2low = cif2_int$low, cif2high = cif2_int$high
)

##################################################################
#### Cumulative Incidence Function of the Main Model Figure 2 ####
##################################################################
## To save the figure, uncomment the png() and dev.off() lines
#png("ECIF_CF.png", width = 800, height = 600)
matplot(
  x = pmatci$time, y = pmatci[, -1], type = "s",
  xlab = "time", ylab = "Estimated Cumulative Incidence Function",
  xlim = c(0, 40), cex.lab = 1.25,
  col = c("blue", "blue", "blue", "red", "red", "red"),
  lty = c(1, 2, 2, 1, 2, 2)
)
## Curve labels placed at fixed plot coordinates
text(x = 38, y = .55, labels = "Repeated", cex = 1.25, col = 4)
text(x = 38, y = .22, labels = "Recast", cex = 1.25, col = 2)
#dev.off()

### Prediction

## Build the prediction inputs for one variable held at each of several values
## while every other covariate stays at its median.
##
## var:       name of a column of the global `med`.
## val_point: values to assign to that column, one scenario per value.
##
## Reads the globals med, bch1, bch2 and main_res, and calls cif().
##
## Returns an unnamed list with one element per value; each element is
## list(covariate row, scaled hazard for cause 1, scaled hazard for cause 2,
## CIFs from cif()).
covmake <- function(var, val_point) {
  # Assigning to an unknown column would silently append a new one that the
  # models ignore, giving predictions identical to the median scenario.
  if (!(var %in% names(med))) {
    stop("covmake(): '", var, "' is not a column of `med`", call. = FALSE)
  }
  lapply(seq_along(val_point), function(p) {
    cova <- med
    cova[, var] <- val_point[p]
    bch1_v <- bch1[, 1] * predict(main_res[[4]], cova, "risk")
    bch2_v <- bch2[, 1] * predict(main_res[[5]], cova, "risk")
    cfm_v <- cif(main_res[[4]], main_res[[5]], cova)
    list(cova, bch1_v, bch2_v, cfm_v)
  })
}

## Variables and the values at which each is evaluated for prediction.
## Every pred_* object stays available as a top-level variable; pred_vars
## collects them, unnamed, in this order (later code indexes it by position).
pred_pko <- covmake("PKO", val_point = c(0, 1))
pred_territory <- covmake("Revolution", val_point = c(0, 1))
pred_actor <- covmake("Num_Actor", val_point = c(1, 2))
pred_victory <- covmake("Victory", val_point = c(0, 1))
pred_agreement <- covmake("Agreement", val_point = c(0, 1))
pred_growth <- covmake("lag_Growth",
                       val_point = quantile(first_cov$lag_Growth, c(0.25, 0.75)))
pred_imr <- covmake("lag_IMR",
                    val_point = quantile(first_cov$lag_IMR, c(0.25, 0.75)))
pred_elf <- covmake("ELF",
                    val_point = quantile(first_cov$ELF, c(0.25, 0.75)))
pred_res <- covmake("Res_Conf", val_point = c(0, 1))
pred_mou <- covmake("Mountain", val_point = c(0, 1))
pred_cw <- covmake("Post_CW", val_point = c(0, 1))

pred_vars <- list(
  pred_pko, pred_territory, pred_actor, pred_victory, pred_agreement,
  pred_growth, pred_imr, pred_elf, pred_res, pred_mou, pred_cw
)

allresult <- vector("list", length(pred_vars))

## Predict the CIF and cumulative hazards for each variable, using the
## bootstrapped baseline hazards. Each entry of allresult stores the
## variable's position in pred_vars and one b_pred() result per value.

for (q in seq_along(allresult)) {
  thisvar <- pred_vars[[q]]
  thisvarout <- vector("list", length(thisvar))
  for (r in seq_along(thisvar)) {
    thisboot <- b_pred(
      cov = thisvar[[r]][[1]], cif = TRUE,
      bchm1 = thisvar[[r]][[2]], bchm2 = thisvar[[r]][[3]],
      cfm = thisvar[[r]][[4]]
    )
    thisvarout[[r]] <- thisboot
  }
  allresult[[q]] <- list(q, thisvarout)
  print(paste(q, "/", length(allresult)))
}

#saveRDS(allresult, "prediction.R")

## Loads previously saved predictions; this replaces the allresult computed
## by the loop above.
allresult <- readRDS("prediction.R")

## Assemble the matrix used to plot the effect of one variable on a hazard or
## on a CIF.
##
## varres: one element of allresult, i.e. list(position in pred_vars,
##         list of b_pred() results, one per value of the variable).
## func:   "haz" selects the hazards; any other value selects the CIFs.
## cause:  1 or 2, the competing cause to extract.
##
## Reads the globals bch1 (time grid) and pred_vars (point estimates), and
## calls conf_int().
##
## Returns a data frame with a `time` column followed, for each value s of
## the variable, by the columns low<s>, high<s> and point<s>.
plotmatrix <- function(varres, func = "haz", cause = 1) {
  bango <- varres[[1]]
  vars <- varres[[2]]
  plotmat <- data.frame(time = bch1[, 2])
  for (s in seq_along(vars)) {
    if (func == "haz") {
      oneval <- conf_int(vars[[s]][[cause]])
      oneval$point <- pred_vars[[bango]][[s]][[1 + cause]]
    } else {
      oneval <- conf_int(vars[[s]][[2 + cause]])
      oneval$point <- pred_vars[[bango]][[s]][[4]][[cause]]
    }
    # Drop the time column by name; plotmat already carries it.
    oneval <- oneval[, c("low", "high", "point")]
    names(oneval) <- paste0(c("low", "high", "point"), s)
    plotmat <- cbind(plotmat, oneval)
  }
  plotmat
}

######################################################
#### Effects of each variable on CIF (Appendix D) ####
######################################################

## Draw the CIF of one cause under the two values of a variable and return
## the plotted matrix.
##
## varres: one element of allresult.
## cause:  1 for repeated conflicts, 2 for recast conflicts; this also picks
##         the axis label and the legend corner.
## labels: legend text for the first (black) and second (red) value.
##
## Dashed lines are the bootstrap band, solid lines the point estimate.
plot_cif_effect <- function(varres, cause, labels) {
  plotmat <- plotmatrix(varres = varres, func = "cif", cause = cause)
  repeated <- cause == 1
  matplot(x = plotmat$time, y = plotmat[, -1], type = "s",
          col = c("black", "black", "black", "red", "red", "red"),
          lty = c(2, 2, 1, 2, 2, 1),
          xlab = "time",
          ylab = if (repeated) {
            "CIF for repeated conflicts"
          } else {
            "CIF for recast conflicts"
          },
          xlim = c(0, 40))
  legend(if (repeated) "bottomright" else "topleft", legend = labels,
         col = c("black", "red"), lty = 1)
  plotmat
}

## Number of groups (Figure 6)

plotnum1 <- plot_cif_effect(allresult[[3]], cause = 1,
                            labels = c("Num_Actor=1", "Num_Actor=2"))
plotnum2 <- plot_cif_effect(allresult[[3]], cause = 2,
                            labels = c("Num_Actor=1", "Num_Actor=2"))

## Revolution (Figure 7)

plotrev1 <- plot_cif_effect(allresult[[2]], cause = 1,
                            labels = c("Revolution=0", "Revolution=1"))
plotrev2 <- plot_cif_effect(allresult[[2]], cause = 2,
                            labels = c("Revolution=0", "Revolution=1"))

## Resource conflict (Figure 8)

plotres1 <- plot_cif_effect(allresult[[9]], cause = 1,
                            labels = c("Res_Conf=0", "Res_Conf=1"))
plotres2 <- plot_cif_effect(allresult[[9]], cause = 2,
                            labels = c("Res_Conf=0", "Res_Conf=1"))

## PKO (Figure 9)

plotpko1 <- plot_cif_effect(allresult[[1]], cause = 1,
                            labels = c("PKO=0", "PKO=1"))
plotpko2 <- plot_cif_effect(allresult[[1]], cause = 2,
                            labels = c("PKO=0", "PKO=1"))

## Mountain (Figure 10)

plotmou1 <- plot_cif_effect(allresult[[10]], cause = 1,
                            labels = c("Mountain=0", "Mountain=1"))
plotmou2 <- plot_cif_effect(allresult[[10]], cause = 2,
                            labels = c("Mountain=0", "Mountain=1"))

## IMR (Figure 11)

quantile(first_cov$lag_IMR, c(0.25, 0.75))

plotimr1 <- plot_cif_effect(allresult[[7]], cause = 1,
                            labels = c("lag_IMR=43.20", "lag_IMR=96.65"))
plotimr2 <- plot_cif_effect(allresult[[7]], cause = 2,
                            labels = c("lag_IMR=43.20", "lag_IMR=96.65"))

## Peace agreement (Figure 12)

plotpa1 <- plot_cif_effect(allresult[[5]], cause = 1,
                           labels = c("Agreement=0", "Agreement=1"))
plotpa2 <- plot_cif_effect(allresult[[5]], cause = 2,
                           labels = c("Agreement=0", "Agreement=1"))

## ELF (Figure 13)

quantile(first_cov$ELF, c(0.25, 0.75))

plotelf1 <- plot_cif_effect(allresult[[8]], cause = 1,
                            labels = c("ELF=0.50", "ELF=0.77"))
plotelf2 <- plot_cif_effect(allresult[[8]], cause = 2,
                            labels = c("ELF=0.50", "ELF=0.77"))

## Victory (Figure 14)

plotvic1 <- plot_cif_effect(allresult[[4]], cause = 1,
                            labels = c("Victory=0", "Victory=1"))
plotvic2 <- plot_cif_effect(allresult[[4]], cause = 2,
                            labels = c("Victory=0", "Victory=1"))

## Growth (Figure 15)

quantile(first_cov$lag_Growth, c(0.25, 0.75))

plotgr1 <- plot_cif_effect(allresult[[6]], cause = 1,
                           labels = c("lag_Growth=-0.01", "lag_Growth=0.08"))
plotgr2 <- plot_cif_effect(allresult[[6]], cause = 2,
                           labels = c("lag_Growth=-0.01", "lag_Growth=0.08"))

## Cold War (Figure 16)

plotcw1 <- plot_cif_effect(allresult[[11]], cause = 1,
                           labels = c("Post CW=0", "Post CW=1"))
plotcw2 <- plot_cif_effect(allresult[[11]], cause = 2,
                           labels = c("Post CW=0", "Post CW=1"))

#########################
#### End of the Code ####
#########################